import datetime
from glob import glob
import os
import time
import json
import math

from loguru import logger
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
import numpy as np
from nanoid import generate
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as Data
from torch.utils.data import DataLoader

from lstm import LSTM
from dataset_loder import LoadData


class GPNetwort(object):
    """End-to-end LSTM pipeline for a single stock: data loading, training,
    evaluation, rolling prediction, and checkpoint save/load.

    The stock is identified by ``gp_id``; its CSV must live under
    ``datasets/`` with a ``<gp_id>_*.csv`` file name.
    """

    def __init__(
        self,
        gp_id,
        space=1460,
        class_name="收盘",
        time_step=30,
        out_days=5,
        test_day_split=90,
        batch_size=32,
        hid_num=128,
    ) -> None:
        """
        gp_id: stock id; matched against the "<id>_" prefix of datasets/*.csv.
        space: history span handed to LoadData (presumably rows/days — TODO
            confirm against LoadData).
        class_name: CSV column to model; default "收盘" (closing price).
        time_step: input window length, in days.
        out_days: number of values predicted per forward pass.
        test_day_split: trailing days reserved for the test split.
        batch_size: DataLoader batch size.
        hid_num: LSTM hidden-layer size.

        Raises:
            ValueError: when no CSV matching ``gp_id`` exists.
        """
        super().__init__()
        files = glob("datasets/*.csv")
        self.gp_id = gp_id
        self.status = None
        self.batch_size = batch_size
        self.hid_num = hid_num
        # CSV names look like "<gp_id>_<name>.csv"; if several match, the
        # last one wins (behavior kept from the original).
        for i in files:
            if os.path.split(i)[1].split("_")[0] == self.gp_id:
                self.status = 200
                self.data_source = i
        if self.status is None:
            logger.error("股票id“{}”不存在，初始化失败", self.gp_id)
            raise ValueError("股票id不存在")
        self.space = space
        self.class_name = class_name
        self.time_step = time_step
        self.out_days = out_days
        self.test_day_split = test_day_split
        self.train_loader, self.test_loader, self.test_data = self.get_data(
            self.data_source, space, class_name, time_step, out_days, test_day_split
        )
        self.model = None  # set later by get_model() or load()

    def get_data(
        self, data_source, space, class_name, time_step, out_days, test_day_split
    ):
        """Build train/test datasets and loaders from one CSV.

        Returns (train_loader, test_loader, test_dataset). The test dataset
        object is also returned because test() needs its normalization
        statistics (``flow_norm``) to de-normalize predictions.
        """
        train_data = LoadData(
            data_source,
            space=space,
            class_name=class_name,
            time_step=time_step,
            out_days=out_days,
            test_day_split=test_day_split,
            train_mode="train",
        )
        test_data = LoadData(
            data_source,
            space=space,
            class_name=class_name,
            time_step=time_step,
            out_days=out_days,
            test_day_split=test_day_split,
            train_mode="test",
        )
        # Only the training split is shuffled; test order stays stable so
        # predictions line up with the time axis.
        train_loader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(test_data, batch_size=self.batch_size, shuffle=False)
        return train_loader, test_loader, test_data

    def get_model(self):
        """Create a fresh LSTM (input width 1, 3 layers, ``hid_num`` hidden
        units, ``out_days`` outputs) and store it on ``self.model``."""
        self.model = LSTM(
            input_num=1, hid_num=self.hid_num, layers_num=3, out_num=self.out_days
        )

    def train(self, epochs=30):
        """Train ``self.model`` with Adam + MSE loss for ``epochs`` epochs.

        For incremental (fine-tuning) runs keep ``epochs`` below 10.

        Raises:
            ValueError: when no model has been created or loaded yet.
        """
        if not self.model:
            raise ValueError("请先load模型或者get模型！")
        loss_func = nn.MSELoss()
        optimizer = torch.optim.Adam(self.model.parameters())
        self.model.train()
        for epoch in range(epochs):
            epoch_loss = 0.0
            start_time = time.time()
            for data_ in self.train_loader:
                # model.zero_grad() clears the same gradients the optimizer
                # owns, since the optimizer holds all model parameters.
                self.model.zero_grad()
                predict = self.model(data_["my_data"])
                loss = loss_func(predict, data_["my_label"].squeeze(2))
                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()
            end_time = time.time()
            logger.info(
                "Epoch: {}, avg_Loss: {}, Time: {} mins",
                epoch,
                epoch_loss / len(self.train_loader),
                (end_time - start_time) / 60,
            )

    def test(self):
        """Evaluate on the held-out test split.

        Runs the model over the test loader, logs the (scaled) average MSE,
        and returns de-normalized predictions and targets.

        Returns:
            (pre_flow, real_flow): two flat numpy arrays — predicted and real
            values, concatenated across all test batches.

        Raises:
            ValueError: when no model has been created or loaded yet.
        """
        if not self.model:
            raise ValueError("请先load模型或者get模型！")
        self.model.eval()
        loss_func = nn.MSELoss()  # loop-invariant; hoisted out of the batch loop
        with torch.no_grad():  # disable gradient tracking for inference
            total_loss = 0.0
            pre_flow = np.array([])
            real_flow = np.array([])
            for data_ in self.test_loader:
                pre_value = self.model(data_["my_data"])
                loss = loss_func(pre_value, data_["my_label"].squeeze(2))
                total_loss += loss.item()
                # De-normalize with the test split's min/max statistics.
                pre_value = LoadData.recoverd_data(
                    pre_value.detach().numpy(),
                    self.test_data.flow_norm[0].squeeze(1),  # max_data
                    self.test_data.flow_norm[1].squeeze(1),  # min_data
                )
                target_value = LoadData.recoverd_data(
                    data_["my_label"].detach().numpy(),
                    self.test_data.flow_norm[0].squeeze(1),
                    self.test_data.flow_norm[1].squeeze(1),
                )
                pre_value = pre_value.flatten()
                target_value = target_value.flatten()
                pre_flow = np.concatenate([pre_flow, pre_value])
                real_flow = np.concatenate([real_flow, target_value])
            # NOTE(review): the 10x factor on the logged loss is kept from the
            # original — presumably a display scaling; confirm with the author.
            logger.info("test_avg_loss:{}", 10 * total_loss / len(self.test_loader))
            return pre_flow, real_flow  # two plain numpy arrays

    def test_rnn(self, data, out_time=30):
        """Roll the model forward to predict future values.

        Intended for real-world prediction: feed the most recent history and
        repeatedly predict ``out_days`` values at a time, feeding predictions
        back in as inputs, until at least ``out_time`` values are produced
        (the result may be slightly longer when ``out_time`` is not a
        multiple of ``out_days``).

        Args:
            data: raw pandas column straight from the CSV (no preprocessing);
                supply at least ``time_step`` values for accuracy.
            out_time: number of future values requested.

        Returns:
            numpy array of de-normalized predictions.

        Raises:
            ValueError: when no model has been created or loaded yet.
        """
        if not self.model:
            raise ValueError("请先load模型或者get模型！")
        self.model.eval()
        normal, pre_datas = LoadData.pre_process_data(
            data.astype(np.float64).values[:, np.newaxis]
        )
        # Take the last time_step points as the seed window.
        pred_datas = pre_datas[-self.time_step :]
        times = math.ceil(out_time / self.out_days)  # forward passes needed
        pred_datas_tensor = LoadData.to_tensor(pred_datas)
        pred_datas_tensor = pred_datas_tensor.unsqueeze(0)  # add a batch dim
        res = np.array([])
        with torch.no_grad():
            for _ in range(times):
                out = self.model(pred_datas_tensor)
                pre_value = LoadData.recoverd_data(
                    out.detach().numpy(),
                    normal[0].squeeze(1),  # max_data
                    normal[1].squeeze(1),  # min_data
                )
                res = np.concatenate((res, pre_value.flatten()))
                # Slide the window: drop the oldest out_days inputs and append
                # the (still-normalized) predictions as new inputs.
                out = out.unsqueeze(2)
                pred_datas_tensor = pred_datas_tensor[:, self.out_days :, :]
                pred_datas_tensor = torch.cat((pred_datas_tensor, out), 1)
        return res

    def save(self):
        """Persist weights and hyper-parameters under weights/.

        Writes "<date>_<gp_id>_<class_name>_<nanoid>.pt" (state dict) plus a
        ".json" with the constructor arguments so load() can rebuild an
        identical model.
        """
        nano = generate(size=10)
        # Build the stem once so the .pt and .json names can never diverge
        # (the original formatted datetime.now() twice, which could straddle
        # midnight and produce mismatched file names).
        stem = (
            f"weights/{datetime.datetime.now().strftime('%Y-%m-%d')}"
            f"_{self.gp_id}_{self.class_name}_{nano}"
        )
        torch.save(self.model.state_dict(), f"{stem}.pt")
        parameters = {
            "gp_id": self.gp_id,
            "batch_size": self.batch_size,
            "space": self.space,
            "class_name": self.class_name,
            "time_step": self.time_step,
            "out_days": self.out_days,
            "test_day_split": self.test_day_split,
            "hid_num": self.hid_num,
        }
        with open(f"{stem}.json", "w", encoding="utf-8") as f:
            json.dump(parameters, f, ensure_ascii=False)

    def load(self, model_path):
        """Restore hyper-parameters and weights written by save().

        Args:
            model_path: path to the ".pt" checkpoint; the companion ".json"
                with the same stem must exist next to it.

        Raises:
            FileNotFoundError: when ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(model_path)
        # The file name starts with the training date: "YYYY-MM-DD_...".
        # Fix: the module-level `datetime` has no strptime — it lives on the
        # datetime.datetime class (the original raised AttributeError here).
        trained = datetime.datetime.strptime(
            os.path.split(model_path)[1].split("_")[0], "%Y-%m-%d"
        )
        time_flow = datetime.datetime.now() - trained
        # Fix: timedelta exposes .days, not .day.
        if time_flow.days > 30:
            logger.warning(
                "距离上次训练已经超过30天，建议使用最新数据进行增量训练或者重新训练，以提升精度。"
            )
        # Fix: save() writes ".pt", but the original replaced ".pth" which
        # never matched; derive the json path from the actual extension.
        with open(
            os.path.splitext(model_path)[0] + ".json", "r", encoding="utf-8"
        ) as f:
            parameters = json.load(f)
            self.gp_id = parameters["gp_id"]
            self.batch_size = parameters["batch_size"]
            self.space = parameters["space"]
            self.class_name = parameters["class_name"]
            self.time_step = parameters["time_step"]
            self.out_days = parameters["out_days"]
            self.test_day_split = parameters["test_day_split"]
            self.hid_num = parameters["hid_num"]
        self.get_model()
        self.model.load_state_dict(torch.load(model_path))


if __name__ == "__main__":
    # Train a one-day-ahead model for stock 600000.
    network = GPNetwort(gp_id="600000", out_days=1)
    network.get_model()
    network.train(epochs=30)
    network.save()

    # Evaluate on the held-out test split.
    network.test()

    # Predict from the raw closing-price column for a sanity check.
    frame = pd.read_csv("datasets/600000_浦发银行.csv")
    predictions = network.test_rnn(frame["收盘"])



    
